import os
import shutil

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
import scipy
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.sandbox.regression.gmm import IV2SLS 

def clean_folder(folder):
    """Ensure ``folder`` exists and is empty.

    If the folder already exists, it is deleted together with all of its
    contents and then recreated empty; otherwise it is created (including
    any missing parent directories).

    Args:
        folder (string): Path to folder

    Raises:
        FileExistsError: If ``folder`` exists but is not a directory.
    """
    if os.path.isdir(folder):
        shutil.rmtree(folder)
    # exist_ok tolerates the directory being recreated by someone else between
    # the rmtree above and this call; a non-directory at the path still raises.
    os.makedirs(folder, exist_ok=True)


def plot_IRF(df, file_name, fontsize = 14, dpi = 600, bias_coefs = False, 
        show = False, ylims = None, location_legend = 'lower right',
        bounds_provided = False, y_label = None, x_label = None):
    """Plot an estimated IRF, or the implied bias coefficients,
    together with 95% and 65% confidence bands.

    The figure is written to ``<file_name>.png`` and ``<file_name>.pgf``,
    and ``df`` is exported to ``<file_name>.csv``. The figure is closed
    before returning.

    NOTE: ``df`` is modified in place. When ``bias_coefs`` is true its
    'IRF' column is negated, and when ``bounds_provided`` is false a
    'p_vals' column is added.

    Args:
        df (dataframe): Dataframe containing
            estimated IRF. Columns read: 'lag', 'IRF', and either
            'se' and 'df_resid' (when ``bounds_provided`` is false) or
            'lb_0.05', 'ub_0.05', 'lb_0.35', 'ub_0.35' (when true)
        file_name (string): Name of file to be saved to (without extension)
        fontsize (int, optional): Font size
        dpi (int, optional): Dots per inch
        bias_coefs (bool, optional): If true, plot
            implied bias coefficients (i.e., multiply 
            IRF by -1)
        show (bool, optional): If true, show plot
        ylims (list, optional): Limits for the y-axis
        location_legend (str, optional): Location of legend
        bounds_provided (bool, optional): If true, df contains
            confidence bounds that are used in the plot
        y_label (string, optional): Y-axis label
        x_label (string, optional): X-axis label
    """

    # Convert to bias coefficients, if asked (in place, see note above)
    if bias_coefs:
        df['IRF'] = -df['IRF']

    # Get the plot
    periods_ahead = df['lag']
    fig = plt.figure()
    plt.plot(periods_ahead, df['IRF'], color = 'r')
    plt.axhline(0, color = 'k', linestyle = '--', alpha = 0.50)

    # Plot confidence bounds.
    # Significance levels are hard-coded and must match the legend labels
    # below (1 - 0.05 = 95%, 1 - 0.35 = 65%).
    sig_levels = [0.05, 0.35]
    colors = ['b', 'r']
    for level, color in zip(sig_levels, colors):
        if not bounds_provided:
            # Two-sided t critical value for this significance level
            t_crit = scipy.stats.t.ppf(1 - level / 2, df['df_resid'])
            lb = df['IRF'] - t_crit * df['se']
            ub = df['IRF'] + t_crit * df['se']
        elif bias_coefs:
            # Negating the estimate flips the interval, so the provided
            # upper bound becomes the lower bound and vice versa
            lb = -df['ub_{}'.format(level)]
            ub = -df['lb_{}'.format(level)]
        else:
            lb = df['lb_{}'.format(level)]
            ub = df['ub_{}'.format(level)]
        plt.fill_between(periods_ahead, lb, ub, 
                         color = color, alpha = 0.15)

    # Add legend for confidence bounds.
    # Raw strings keep the literal backslash (LaTeX-escaped %) without an
    # invalid-escape warning.
    conf_95 = mpatches.Patch(color = colors[0], alpha = 0.15,
                             label = r'95\% Confidence Interval')
    conf_65 = mpatches.Patch(color = colors[1], alpha = 0.15, 
                             label = r'65\% Confidence Interval')
    plt.legend(handles = [conf_65, conf_95],
               frameon = False,
               fontsize = fontsize,
               loc = location_legend)

    plt.xticks(periods_ahead)
    plt.ylabel(y_label, fontsize = fontsize)
    plt.xlabel(x_label, fontsize = fontsize)

    if ylims:
        plt.ylim(ylims)

    plt.tight_layout()
    plt.savefig("{}.png".format(file_name), dpi = dpi)
    plt.savefig("{}.pgf".format(file_name), dpi = dpi)

    # Calculate two-sided p-values. The survival function is used instead of
    # 1 - cdf to avoid cancellation to 0 for large |t|.
    if not bounds_provided:
        t_stats = df['IRF'] / df['se']
        p_vals = 2 * scipy.stats.t.sf(np.abs(t_stats), df['df_resid'])
        df['p_vals'] = p_vals

    # Export data to file
    df.to_csv("{}.csv".format(file_name))
    if show:
        plt.show()

    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)


def plot_rolling_window_estimates(df, file_name, fontsize = 14, dpi = 600,
        y_label = None, location_legend = 'lower right', show = False):
    """Plot rolling window estimates with 95% and 65% confidence bands.

    The figure is written to ``<file_name>.png`` and ``<file_name>.pgf``,
    and ``df`` is exported to ``<file_name>.csv``. The figure is closed
    before returning.

    Args:
        df (dataframe): Dataframe containing
            rolling window estimates. Its index is used as the x-axis;
            columns read: 'IRF', 'se', 'df_resid'
        file_name (string): Name of file to be saved to (without extension)
        fontsize (int, optional): Font size
        dpi (int, optional): Dots per inch
        y_label (string, optional): Y-axis label
        location_legend (str, optional): Location of legend
        show (bool, optional): If true, show plot before closing it
    """
    fig = plt.figure()
    plt.plot(df.index, df['IRF'], 'r')
    plt.axhline(0, color = 'k', linestyle = '--', alpha = 0.50)

    # Plot confidence bounds.
    # Significance levels are hard-coded and must match the legend labels
    # below (1 - 0.05 = 95%, 1 - 0.35 = 65%).
    sig_levels = [0.05, 0.35]
    colors = ['b', 'r']
    for level, color in zip(sig_levels, colors):
        # Two-sided t critical value for this significance level
        t_crit = scipy.stats.t.ppf(1 - level / 2, df['df_resid'])
        plt.fill_between(df.index, 
                         df['IRF'] - t_crit * df['se'], 
                         df['IRF'] + t_crit * df['se'], 
                         color = color, alpha = 0.15)
    plt.xlim([df.index.min(), df.index.max()])
    plt.ylabel(y_label, fontsize = fontsize)

    # Raw strings keep the literal backslash (LaTeX-escaped %) without an
    # invalid-escape warning.
    conf_95 = mpatches.Patch(color = colors[0], alpha = 0.15,
                             label = r'95\% Confidence Interval')
    conf_65 = mpatches.Patch(color = colors[1], alpha = 0.15, 
                             label = r'65\% Confidence Interval')
    plt.legend(handles = [conf_65, conf_95],
               frameon = False,
               fontsize = fontsize,
               loc = location_legend)
    plt.savefig("{}.png".format(file_name), dpi = dpi)
    plt.savefig("{}.pgf".format(file_name), dpi = dpi)

    # Export data to file
    df.to_csv("{}.csv".format(file_name))
    if show:
        plt.show()

    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)


def construct_state_dep_table(IRFs, IRF_names, max_K):
    """Build a LaTeX ``tabular`` of state dependent bias coefficients.

    One column per entry of ``IRFs``. For each lag 1..max_K three rows
    are emitted: the point estimate, the standard error in parentheses,
    and the t-statistic in brackets (all to two decimals). A bracketed
    ``nan`` t-statistic is rendered as blank space.

    Args:
        IRFs (list): Dataframes of estimated bias coefficients; each must
            have 'lag', 'IRF', 'se' and 't_stat' columns
        IRF_names (list): Column headers, one per dataframe
        max_K (int): Number of estimated bias coefficients (lags) to show

    Returns:
        str: String containing LaTeX table
    """
    def lookup(frame, lag, column):
        # First value of `column` in the row(s) of `frame` at this lag
        return frame.loc[frame['lag'] == lag, column].values[0]

    header = ' ' + ''.join('& {} '.format(name) for name in IRF_names)
    pieces = [
        '\\begin{tabular}{' + 'c' * (len(IRFs) + 1) + '} \n',
        '\\toprule \n',
        header,
        '\\\\ \n \\midrule',
    ]

    for lag in range(1, max_K + 1):
        estimates = ''.join('& {:.2f} '.format(lookup(frame, lag, 'IRF'))
                            for frame in IRFs)
        std_errs = ''.join('& \\footnotesize{{({:.2f})}} '.format(lookup(frame, lag, 'se'))
                           for frame in IRFs)
        t_values = ''.join('& \\footnotesize{{[{:.2f}]}} '.format(lookup(frame, lag, 't_stat'))
                           for frame in IRFs)
        pieces.append(' ' + estimates + '\\\\ \n')
        pieces.append('Lag {} '.format(lag) + std_errs + '\\\\ \n')
        pieces.append(' ' + t_values + '\\\\ \n')
        # Separator between lag groups, but not after the last one
        if lag < max_K:
            pieces.append('\\midrule')

    # Blank out missing t-statistics before closing the table
    body = ''.join(pieces).replace('[nan]', ' ')
    return body + '\\bottomrule \n' + '\\end{tabular}'